import cv2
import time
import numpy as np
from random import randint
from rknn.api import RKNN

# Module-level RKNN handle used by the rest of the script: the pre-compiled
# model is loaded into it and its runtime initialised before the capture
# loop, and it is released at the end of the script.
rknn = RKNN()
# Caffe model files this RKNN model was originally converted from:
#   protoFile   = "pose/coco/pose_deploy_linevec.prototxt"
#   weightsFile = "pose/coco/pose_iter_440000.caffemodel"
#
# The COCO model describes 18 keypoints; this script keeps only the first 14
# (the face points R-Eye=14, L-Eye=15, R-Ear=16, L-Ear=17 are dropped).
# For reference, the full 18-point configuration appends these limbs, in this
# order, to POSE_PAIRS together with their PAF channel pairs in mapIdx:
#   [0,14]->[49,50], [14,16]->[53,54], [0,15]->[51,52], [15,17]->[55,56],
#   [2,16]->[37,38], [5,17]->[45,46]
# and extends colors with [200,200,0], [255,0,0], [200,200,0], [0,0,0].

# Body-part names, indexed by keypoint id (= heatmap channel).
keypointsMapping = [
    'Nose', 'Neck',
    'R-Sho', 'R-Elb', 'R-Wr',
    'L-Sho', 'L-Elb', 'L-Wr',
    'R-Hip', 'R-Knee', 'R-Ank',
    'L-Hip', 'L-Knee', 'L-Ank',
]

# Number of keypoint heatmap channels read from the network output.
nPoints = len(keypointsMapping)

# Limbs as [keypoint A, keypoint B] (indices into keypointsMapping).
# Entry i lines up with mapIdx[i] (the two PAF output channels of that limb)
# and colors[i] (the line color used when drawing it).
POSE_PAIRS = [
    [1, 2],    # Neck   - R-Sho
    [1, 5],    # Neck   - L-Sho
    [2, 3],    # R-Sho  - R-Elb
    [3, 4],    # R-Elb  - R-Wr
    [5, 6],    # L-Sho  - L-Elb
    [6, 7],    # L-Elb  - L-Wr
    [1, 8],    # Neck   - R-Hip
    [8, 9],    # R-Hip  - R-Knee
    [9, 10],   # R-Knee - R-Ank
    [1, 11],   # Neck   - L-Hip
    [11, 12],  # L-Hip  - L-Knee
    [12, 13],  # L-Knee - L-Ank
    [1, 0],    # Neck   - Nose
]

# PAF channel indices for each limb of POSE_PAIRS, e.g. limb [1,2] reads its
# affinity field from output channels 31 and 32.
mapIdx = [
    [31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44],
    [19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30],
    [47, 48],
]

# Drawing colors, one per limb position.
colors = [
    [0, 100, 255], [0, 100, 255], [0, 255, 255],
    [0, 100, 255], [0, 255, 255], [0, 100, 255],
    [0, 255, 0], [255, 200, 100], [255, 0, 255],
    [0, 255, 0], [255, 200, 100], [255, 0, 255],
    [0, 0, 255], [255, 0, 0],
]


def getKeypoints(probMap, threshold=0.1):
    """Extract one candidate keypoint per blob of a single-part confidence map.

    The map is smoothed with a 3x3 Gaussian and thresholded at ``threshold``
    to get a binary mask. Each contour found in that mask (RETR_TREE, so
    nested contours are included too) is rasterised into its own mask. The
    location of the smoothed maximum inside it becomes the keypoint.

    Args:
        probMap: 2-D confidence map for one body part.
        threshold: minimum smoothed confidence for a pixel to belong to a blob.

    Returns:
        list of ``(x, y, prob)`` tuples. ``(x, y)`` is the pixel of the blob
        maximum and ``prob`` is the *unsmoothed* ``probMap`` value there.
    """
    mapSmooth = cv2.GaussianBlur(probMap, (3, 3), 0, 0)
    mapMask = np.uint8(mapSmooth > threshold)

    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); the contours are second-to-last in both.
    contours = cv2.findContours(mapMask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    keypoints = []
    for cnt in contours:
        # fillPoly fills arbitrary (non-convex) contours correctly, which
        # fillConvexPoly only guarantees for convex/monotonic polygons.
        blobMask = np.zeros(mapMask.shape)
        cv2.fillPoly(blobMask, [cnt], 1)
        _, _, _, maxLoc = cv2.minMaxLoc(mapSmooth * blobMask)
        # maxLoc is (x, y); numpy indexing is [row=y, col=x].
        keypoints.append(maxLoc + (probMap[maxLoc[1], maxLoc[0]],))

    return keypoints


# Find valid connections between the different joints of a all persons present
def getValidPairs(output, keypoints_by_part=None, frame_size=None):
    """Score candidate limbs with the Part Affinity Fields and keep the best ones.

    For every limb ``k`` in ``POSE_PAIRS``, each candidate of the limb's first
    joint (candA) is tested against every candidate of its second joint
    (candB).

    The two PAF channels ``mapIdx[k]`` are sampled at ``n_interp_samples``
    points along the segment A->B, and each sample is projected onto the unit
    vector A->B. A connection is accepted when more than ``conf_th`` of the
    samples exceed ``paf_score_th``. For each A, the accepted B with the
    highest mean projection is kept.

    Args:
        output: network output indexed as ``output[0, channel, row, col]``.
        keypoints_by_part: per-part lists of ``(x, y, prob, keypoint_id)``
            tuples. Defaults to the module-level ``detected_keypoints``.
        frame_size: ``(width, height)`` that the PAF channels are resized to
            before sampling. Defaults to the module-level
            ``(frameWidth, frameHeight)``.

    Returns:
        ``(valid_pairs, invalid_pairs)``.
        ``valid_pairs[k]`` is an ``(m, 3)`` array of ``[id_A, id_B, score]``
        rows for limb ``k``, or an empty list when one of its joints had no
        candidates.
        ``invalid_pairs`` lists the indices of those limbs.
    """
    if keypoints_by_part is None:
        keypoints_by_part = detected_keypoints
    if frame_size is None:
        frame_size = (frameWidth, frameHeight)

    valid_pairs = []
    invalid_pairs = []
    n_interp_samples = 10
    paf_score_th = 0.1
    conf_th = 0.7

    for k in range(len(mapIdx)):
        candA = keypoints_by_part[POSE_PAIRS[k][0]]
        candB = keypoints_by_part[POSE_PAIRS[k][1]]
        if len(candA) == 0 or len(candB) == 0:
            invalid_pairs.append(k)
            valid_pairs.append([])
            continue

        # Upsample this limb's PAF channels to frame resolution so they can be
        # indexed with keypoint pixel coordinates. This is done only for limbs
        # that can actually be scored.
        pafA = cv2.resize(output[0, mapIdx[k][0], :, :], frame_size)
        pafB = cv2.resize(output[0, mapIdx[k][1], :, :], frame_size)

        rows = []
        for a in candA:
            best_j = -1
            best_score = -1
            for j, b in enumerate(candB):
                d_ij = np.subtract(b[:2], a[:2])
                norm = np.linalg.norm(d_ij)
                if not norm:
                    # Coincident joints: the limb direction is undefined.
                    continue
                d_ij = d_ij / norm

                # Sample the field at evenly spaced points on A->B. np.rint
                # rounds half to even, like the built-in round().
                xs = np.rint(np.linspace(a[0], b[0], num=n_interp_samples)).astype(int)
                ys = np.rint(np.linspace(a[1], b[1], num=n_interp_samples)).astype(int)
                paf_interp = np.stack([pafA[ys, xs], pafB[ys, xs]], axis=1)

                # Alignment of each sampled field vector with the limb direction.
                paf_scores = np.dot(paf_interp, d_ij)
                avg_paf_score = sum(paf_scores) / len(paf_scores)

                # Valid when enough samples point along the limb; keep the best B.
                aligned = np.count_nonzero(paf_scores > paf_score_th) / n_interp_samples
                if aligned > conf_th and avg_paf_score > best_score:
                    best_j = j
                    best_score = avg_paf_score

            if best_j >= 0:
                rows.append([a[3], candB[best_j][3], best_score])

        valid_pairs.append(np.array(rows, dtype=float).reshape(-1, 3))

    return valid_pairs, invalid_pairs



# This function creates a list of keypoints belonging to each person
# For each detected valid pair, it assigns the joint(s) to a person
def getPersonwiseKeypoints(valid_pairs, invalid_pairs, keypoint_table=None, pose_pairs=None):
    """Group the limb connections produced by getValidPairs into people.

    Each returned row describes one person. Column ``p`` holds the global
    keypoint id assigned to body part ``p``, or ``-1`` if none. The last
    column accumulates a score: the keypoint probabilities plus the PAF
    scores of the limbs that built the person.

    Connections are processed limb by limb. If a limb's first keypoint
    already belongs to a person, its second keypoint is attached to that
    person. Otherwise a new person is started.

    Args:
        valid_pairs: per limb, an ``(m, 3)`` array of ``[id_A, id_B, score]``
            rows. Entries whose index is in ``invalid_pairs`` are skipped.
        invalid_pairs: indices of limbs that had no candidate keypoints.
        keypoint_table: array whose row ``id`` is ``(x, y, prob)`` for
            keypoint ``id``. Defaults to the module-level ``keypoints_list``.
        pose_pairs: ``[part_A, part_B]`` per limb. Defaults to the
            module-level ``POSE_PAIRS``.

    Returns:
        ``(n_people, 19)`` float array: 18 part columns (COCO layout) plus
        the trailing score column.
    """
    table = keypoints_list if keypoint_table is None else keypoint_table
    pairs = POSE_PAIRS if pose_pairs is None else pose_pairs
    skipped = set(invalid_pairs)

    # Start with no people; every row is 18 part slots + 1 score slot.
    personwiseKeypoints = -1 * np.ones((0, 19))
    for k, (indexA, indexB) in enumerate(pairs):
        if k in skipped:
            continue
        for partA, partB, pair_score in valid_pairs[k]:
            # Does partA already belong to someone? Take the first match.
            owners = np.flatnonzero(personwiseKeypoints[:, indexA] == partA)
            if owners.size:
                # partB belongs to the same person as partA.
                person_idx = owners[0]
                personwiseKeypoints[person_idx, indexB] = partB
                personwiseKeypoints[person_idx, -1] += table[int(partB), 2] + pair_score
            elif k < 17:
                # Start a new person. In the 19-limb COCO layout, limbs 17 and
                # 18 (the shoulder-ear links) may only extend existing people.
                row = -1 * np.ones(19)
                row[indexA] = partA
                row[indexB] = partB
                row[-1] = table[int(partA), 2] + table[int(partB), 2] + pair_score
                personwiseKeypoints = np.vstack([personwiseKeypoints, row])

    return personwiseKeypoints


# Every processed frame is resized to inWidth x inHeight before inference.
inWidth = 368
inHeight = 368

# Load the pre-compiled model and start the NPU runtime. Bail out with the
# returned status (releasing the handle) if either step fails.
ret = rknn.load_rknn('./pose_deploy_linevec_pre_compile.rknn')
if ret != 0:
    print('Load RKNN model failed')
    rknn.release()
    raise SystemExit(ret)
ret = rknn.init_runtime()
if ret != 0:
    print('Init runtime environment failed')
    rknn.release()
    raise SystemExit(ret)
print('done')

cap = cv2.VideoCapture(10)
if not cap.isOpened():
    print('Open camera failed')
    rknn.release()
    raise SystemExit(1)

hasFrame, frame = cap.read()
fps = cap.get(cv2.CAP_PROP_FPS)

# Only every `scale`-th frame is run through the network, to keep up with the
# camera: reads arrive at roughly 15 fps (~0.05 s each) while one processing
# pass takes ~0.2 s (author's measurements).
count = 0
scale = 4

try:
    while True:
        # Read the keyboard exactly once per iteration: q quits, space pauses
        # until the next key press.
        key = cv2.waitKey(1) & 0xff
        if key == ord("q"):
            break
        if key == ord(" "):
            cv2.waitKey(0)

        hasFrame, frame = cap.read()
        if not hasFrame:
            # End of stream: keep the last window up until a key is pressed.
            cv2.waitKey()
            break

        count += 1
        if count != scale:
            print("count = " + str(count))
            continue
        count = 0

        # Resize the input image to 368x368.
        frame = cv2.resize(frame, (inWidth, inHeight), interpolation=cv2.INTER_CUBIC)
        # Module-level globals read by getValidPairs.
        frameWidth = frame.shape[1]
        frameHeight = frame.shape[0]

        # Convert the HWC frame to the 'nchw' input layout.
        frame_input = np.transpose(frame, [2, 0, 1])
        t = time.time()
        [output] = rknn.inference(inputs=[frame_input], data_format="nchw")
        print("time:", time.time() - t)

        # The RKNN output is a flat buffer; view it as a 1x57x46x46 tensor.
        output = output.reshape(1, 57, 46, 46)

        # Module-level globals read by getValidPairs / getPersonwiseKeypoints:
        # per-part candidate lists and a global (x, y, prob) table indexed by
        # keypoint id.
        detected_keypoints = []
        keypoints_list = np.zeros((0, 3))
        keypoint_id = 0
        threshold = 0.1

        for part in range(nPoints):
            probMap = cv2.resize(output[0, part, :, :], (frameWidth, frameHeight))
            keypoints = getKeypoints(probMap, threshold)
            keypoints_with_id = []
            for kp in keypoints:
                keypoints_with_id.append(kp + (keypoint_id,))
                keypoints_list = np.vstack([keypoints_list, kp])
                keypoint_id += 1
            detected_keypoints.append(keypoints_with_id)

        frameClone = frame.copy()

        valid_pairs, invalid_pairs = getValidPairs(output)
        personwiseKeypoints = getPersonwiseKeypoints(valid_pairs, invalid_pairs)

        # Draw each person's limbs.
        for i, pair in enumerate(POSE_PAIRS):
            for person in personwiseKeypoints:
                index = person[np.array(pair)]
                if -1 in index:
                    continue
                ids = index.astype(int)
                xs = np.int32(keypoints_list[ids, 0])
                ys = np.int32(keypoints_list[ids, 1])
                # Pass plain ints: some OpenCV builds reject numpy scalars as points.
                cv2.line(frameClone, (int(xs[0]), int(ys[0])), (int(xs[1]), int(ys[1])),
                         colors[i], 3, cv2.LINE_AA)

        cv2.imshow("Detected Pose", frameClone)
finally:
    cap.release()
    cv2.destroyAllWindows()
    rknn.release()